Skip to content

LangGraph Tools 集成方案

一、模块概述

属性说明
模块名称Tools(工具调用)+ Conditional Edges(条件边)
优先级🔴 P0(最高)
预估工时2-3 天
依赖项langchain-core, langgraph

为什么需要

当前 Agent 只能进行纯对话,无法执行实际操作或获取实时信息。集成 Tools 后可以:

  • 联网搜索获取实时信息
  • 查询数据库获取业务数据
  • 调用外部 API 执行操作

二、架构设计

2.1 当前架构 vs 目标架构

当前架构(单节点线性):

START ──► [agent/call_model] ──► END

目标架构(带工具调用):

                    ┌─────────────────────────┐
                    │        START           │
                    └───────────┬─────────────┘


                    ┌─────────────────────────┐
                    │        agent            │
                    │    (LLM + 工具绑定)     │
                    └───────────┬─────────────┘

                    ┌───────────┴───────────┐
                    │                     │
              should_continue() 条件判断
                    │                     │
          ┌─────────┴─────────┐
          │                       │
          ▼                       ▼
┌─────────────────┐     ┌─────────────────┐
│     tools       │     │      END        │
│  (执行工具调用)   │     │   (无工具调用)   │
└────────┬────────┘     └─────────────────┘

         │ (执行完成后)

         └──────────────────┐


                    ┌─────────────────┐
                    │     agent       │
                    │ (处理工具结果)  │
                    └─────────────────┘

2.2 状态设计

python
from typing import TypedDict, Annotated
from langgraph.graph import MessagesState

# 使用 LangGraph 内置的 MessagesState
# 它已经包含了 messages 字段,使用 add reducer 自动合并消息

三、代码实现

3.1 工具定义

创建文件: services/tools/__init__.py

python
"""LangGraph Agent 工具集"""
from langchain.tools import tool
from typing import Optional
import logging

logger = logging.getLogger(__name__)


@tool
def search_web(query: str) -> str:
    """
    搜索网络获取实时信息。

    当用户询问天气、新闻、股价、实时数据等问题时使用此工具。

    Args:
        query: 搜索关键词

    Returns:
        搜索结果摘要
    """
    # TODO: 集成搜索 API(如 Tavily、SerpAPI、Google Custom Search)
    # 示例实现:
    try:
        # 这里可以替换为实际的搜索 API
        # from langchain_community.tools import TavilySearchResults
        # search = TavilySearchResults(max_results=3)
        # return search.invoke(query)

        # 临时返回提示信息
        return f"搜索功能待集成。搜索关键词: {query}"
    except Exception as e:
        logger.error(f"搜索失败: {e}")
        return f"搜索出错: {str(e)}"


@tool
def get_current_time(timezone: Optional[str] = None) -> str:
    """
    获取当前时间。

    当用户询问时间、日期时使用此工具。

    Args:
        timezone: 时区(可选,默认为本地时区)

    Returns:
        当前时间字符串
    """
    from datetime import datetime
    import pytz

    if timezone:
        try:
            tz = pytz.timezone(timezone)
            return datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
        except pytz.UnknownTimeZoneError:
            return f"未知时区: {timezone}"
    else:
        return datetime.now().strftime("%Y-%m-%d %H:%M:%S")


@tool
def calculate(expression: str) -> str:
    """
    执行数学计算。

    当用户需要进行数学运算、单位转换等时使用此工具。

    Args:
        expression: 数学表达式,如 "2 + 3 * 4" 或 "100 / 5"

    Returns:
        计算结果
    """
    try:
        # 安全地计算表达式(仅允许数学运算)
        import ast
        import operator

        allowed_operators = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.Pow: operator.pow,
            ast.Mod: operator.mod,
        }

        def eval_expr(node):
            if isinstance(node, ast.Num):
                return node.n
            elif isinstance(node, ast.BinOp):
                op = allowed_operators.get(type(node.op))
                if op is None:
                    raise ValueError(f"不支持的操作符: {type(node.op)}")
                return op(eval_expr(node.left), eval_expr(node.right))
            else:
                raise ValueError(f"不支持的表达式: {type(node)}")

        tree = ast.parse(expression, mode='eval')
        result = eval_expr(tree.body)
        return f"计算结果: {expression} = {result}"
    except Exception as e:
        logger.error(f"计算失败: {e}")
        return f"计算错误: {str(e)}"


# 导出所有工具
ALL_TOOLS = [search_web, get_current_time, calculate]

3.2 Agent 服务改造

修改文件: services/langgraph_agent.py

python
"""LangGraph Agent 服务 - 支持工具调用

基于 LangGraph 的 ReAct Agent 实现,支持:
- 自动工具调用
- 多轮对话
- 流式输出
- 持久化(MySQL Checkpointer)
"""
import os
import logging
from typing import Optional, Iterator, Dict, Any, List, Literal
from dataclasses import dataclass, field
from dotenv import load_dotenv

from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from langgraph.graph import START, StateGraph, MessagesState, END
from langgraph.prebuilt import ToolNode

from services.tools import ALL_TOOLS

load_dotenv(override=True)

logger = logging.getLogger(__name__)


@dataclass
class AgentConfig:
    """Agent 配置"""
    api_key: str = field(default_factory=lambda: os.getenv("OPENROUTER_API_KEY", ""))
    base_url: str = field(default_factory=lambda: os.getenv("OPENROUTER_BASE_URL", "https://openrouter.ai/api/v1"))
    default_model: str = field(default_factory=lambda: os.getenv("OPENROUTER_MODEL", "openai/gpt-4o"))


class ToolEnabledAgent:
    """支持工具调用的 Agent"""

    def __init__(self, config: Optional[AgentConfig] = None):
        self.config = config or AgentConfig()
        self.tools = ALL_TOOLS
        self.tools_by_name = {tool.name: tool for tool in self.tools}

    def _get_llm(self, model: Optional[str] = None, temperature: float = 0.7) -> ChatOpenAI:
        """获取绑定了工具的 LLM"""
        llm = ChatOpenAI(
            model=model or self.config.default_model,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            temperature=temperature
        )
        # 绑定工具
        return llm.bind_tools(self.tools)

    def _should_continue(self, state: MessagesState) -> Literal["tools", END]:
        """判断是否需要调用工具"""
        last_message = state["messages"][-1]
        # 如果最后一条消息有工具调用,则执行工具
        if last_message.tool_calls:
            return "tools"
        # 否则结束
        return END

    def _call_model(self, state: MessagesState, model: Optional[str] = None, system_prompt: Optional[str] = None):
        """调用 LLM 节点"""
        llm = self._get_llm(model)

        # 构建系统提示
        system_msg = system_prompt or "你是一个有用的 AI 助手。你可以使用工具来帮助用户。"

        messages = [
            SystemMessage(content=system_msg),
            *state["messages"]
        ]

        response = llm.invoke(messages)
        return {"messages": [response]}

    def _build_graph(self, checkpointer, model: Optional[str] = None, system_prompt: Optional[str] = None):
        """构建带工具调用的 Graph"""
        # 创建工具节点
        tool_node = ToolNode(self.tools)

        # 构建工作流
        workflow = StateGraph(MessagesState)

        # 添加节点
        workflow.add_node(
            "agent",
            lambda state: self._call_model(state, model, system_prompt)
        )
        workflow.add_node("tools", tool_node)

        # 添加边
        workflow.add_edge(START, "agent")
        workflow.add_conditional_edges(
            "agent",
            self._should_continue,
            {"tools": "tools", END: END}
        )
        workflow.add_edge("tools", "agent")  # 工具执行后返回 agent

        return workflow.compile(checkpointer=checkpointer)

    def chat(
        self,
        thread_id: str,
        prompt: str,
        model: Optional[str] = None,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        **kwargs
    ) -> Dict[str, Any]:
        """同步对话"""
        from services.checkpointer import get_checkpointer

        with get_checkpointer() as checkpointer:
            app = self._build_graph(checkpointer, model, system_prompt)
            config = {"configurable": {"thread_id": thread_id}}

            result = app.invoke(
                {"messages": [HumanMessage(content=prompt)]},
                config
            )

            last_message = result["messages"][-1]
            return {"content": last_message.content}

    def chat_stream(
        self,
        thread_id: str,
        prompt: str,
        model: Optional[str] = None,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        **kwargs
    ) -> Iterator[str]:
        """流式对话"""
        from services.checkpointer import get_checkpointer

        with get_checkpointer() as checkpointer:
            app = self._build_graph(checkpointer, model, system_prompt)
            config = {"configurable": {"thread_id": thread_id}}

            for chunk in app.stream(
                {"messages": [HumanMessage(content=prompt)]},
                config,
                stream_mode="messages"
            ):
                if isinstance(chunk, tuple) and len(chunk) == 2:
                    message_chunk, metadata = chunk
                    if hasattr(message_chunk, 'content') and message_chunk.content:
                        yield message_chunk.content


# 全局实例
_agent: Optional[ToolEnabledAgent] = None


def get_agent() -> ToolEnabledAgent:
    """获取 Agent 单例"""
    global _agent
    if _agent is None:
        _agent = ToolEnabledAgent()
    return _agent

四、API 集成

4.1 路由修改

修改 api/chat.py:

python
# 添加新的路由端点
from services.langgraph_agent import get_agent

@router.post("/chat/agent/stream")
async def agent_chat_stream(
    data: ChatRequest,
    request: Request,
    db: Session = Depends(get_db)
):
    """带工具调用的 Agent 流式聊天"""
    session = get_session(request)
    user_id = session.get("user", {}).get("user_id")

    agent = get_agent()

    async def generate():
        full_content = ""
        try:
            for chunk in agent.chat_stream(
                thread_id=data.conversation_id or "anonymous",
                prompt=data.prompt,
                model=data.model,
                system_prompt=data.system_prompt,
            ):
                full_content += chunk
                yield f"data: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"

            yield f"data: {json.dumps({'done': True}, ensure_ascii=False)}\n\n"
        except Exception as e:
            logger.error(f"Agent chat error: {e}", exc_info=True)
            yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"

    return StreamingResponse(
        generate(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",
        }
    )

五、前端集成

5.1 工具调用展示

修改 static/js/chat.js:

javascript
// 处理工具调用展示
function handleToolCall(toolCall) {
    const toolElement = document.createElement('div');
    toolElement.className = 'tool-call';

    const toolIcon = document.createElement('span');
    toolIcon.className = 'tool-icon';
    toolIcon.textContent = '🔧';

    const toolName = document.createElement('span');
    toolName.className = 'tool-name';
    toolName.textContent = toolCall.name;

    const toolArgs = document.createElement('pre');
    toolArgs.className = 'tool-args';
    toolArgs.textContent = JSON.stringify(toolCall.args, null, 2);

    toolElement.appendChild(toolIcon);
    toolElement.appendChild(toolName);
    toolElement.appendChild(toolArgs);

    return toolElement;
}

// 处理工具结果展示
function handleToolResult(result) {
    const resultElement = document.createElement('div');
    resultElement.className = 'tool-result';
    resultElement.textContent = `工具结果: ${result}`;
    return resultElement;
}

5.2 CSS 样式

添加到 static/css/index.css:

css
/* 工具调用样式 */
.tool-call {
    background: #f8f9fa;
    border-left: 3px solid #4a90d9;
    padding: 10px 15px;
    margin: 10px 0;
    border-radius: 4px;
    font-size: 14px;
}

.tool-icon {
    margin-right: 8px;
}

.tool-name {
    font-weight: bold;
    color: #4a90d9;
}

.tool-args {
    background: #fff;
    padding: 8px;
    margin-top: 8px;
    border-radius: 4px;
    font-size: 12px;
    overflow-x: auto;
}

.tool-result {
    background: #e8f5e9;
    border-left: 3px solid #4caf50;
    padding: 10px 15px;
    margin: 10px 0;
    border-radius: 4px;
    font-size: 14px;
}

六、测试计划

6.1 单元测试

python
# tests/test_tools.py
import pytest
from services.tools import search_web, get_current_time, calculate


def test_get_current_time():
    """测试获取时间工具"""
    result = get_current_time.invoke({"timezone": None})
    assert result  # 应返回时间字符串

    result = get_current_time.invoke({"timezone": "Asia/Shanghai"})
    assert "Asia/Shanghai" not in result  # 正常时区应返回时间


def test_calculate():
    """测试计算工具"""
    assert "6" in calculate.invoke({"expression": "2 + 2 * 2"})
    assert "20" in calculate.invoke({"expression": "100 / 5"})
    assert "错误" in calculate.invoke({"expression": "invalid"})


def test_search_web():
    """测试搜索工具"""
    result = search_web.invoke({"query": "test"})
    assert "搜索" in result

6.2 集成测试

python
# tests/test_agent.py
import pytest
from services.langgraph_agent import get_agent


def test_agent_basic_chat():
    """测试基本对话"""
    agent = get_agent()
    result = agent.chat(
        thread_id="test-1",
        prompt="你好"
    )
    assert result["content"]


def test_agent_tool_call():
    """测试工具调用"""
    agent = get_agent()
    result = agent.chat(
        thread_id="test-2",
        prompt="现在几点了?"
    )
    assert result["content"]
    # 应该调用了 get_current_time 工具

七、实施步骤

步骤 1: 创建工具定义(0.5 天)

  • 创建 services/tools/__init__.py
  • 实现基础工具(search_web, get_current_time, calculate)
  • 编写工具单元测试

步骤 2: 创建 Agent 服务(1 天)

  • 创建 services/langgraph_agent.py
  • 实现 _should_continue 条件边
  • 实现 _call_model 和工具节点
  • 编写 Agent 单元测试

步骤 3: API 集成(0.5 天)

  • 添加 /api/chat/agent/stream 端点
  • 更新 API 文档

步骤 4: 前端集成(0.5 天)

  • 添加工具调用展示组件
  • 添加 CSS 样式
  • 测试完整流程

步骤 5: 测试和优化(0.5 天)

  • 端到端测试
  • 性能测试
  • 修复问题

八、相关文档